"""
检测代码
"""
import torch.nn

from fashion_mnist import *
import toolpy

# Experiment hyperparameters, named once so both model variants below use the
# same values. (The original header here read "检测单轮次数据" — "check a single
# epoch" — but training actually runs for NUM_EPOCHS epochs; set NUM_EPOCHS = 1
# for a single pass.)
batch_size = 20
NUM_EPOCHS = 10
LEARNING_RATE = 0.1
NUM_INPUTS = 784   # features per flattened image; presumably 28 * 28 pixels — confirm against the loader
NUM_OUTPUTS = 10   # number of output classes (logits)

train_iter, test_iter = load_data_fashion_mnist(batch_size)

# Alternative variant: the hand-written network ("使用自己创造的网络" — "use our
# own network"). Kept for reference; uncommenting it trains Net1 before the
# torch.nn model below.
# net = Net1()
# loss = cross_entropy
# updater = net.updater
#
# timer = toolpy.Timer()
# train_ch3(net, train_iter, test_iter, NUM_EPOCHS, loss, updater)
#
# print("花时："+str(timer.stop())+"sec")

# Softmax regression built from torch.nn: flatten each input, then a single
# linear layer producing NUM_OUTPUTS logits.
net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(NUM_INPUTS, NUM_OUTPUTS),
)
# reduction='none' yields one loss value per example; presumably train_ch3
# reduces them (mean/sum) itself before backward — TODO confirm in fashion_mnist.
loss = torch.nn.CrossEntropyLoss(reduction='none')
updater = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)

# The timer starts after data loading and model construction, so the printed
# time covers only training plus prediction.
timer = toolpy.Timer()

train_ch3(net, train_iter, test_iter, NUM_EPOCHS, loss, updater)
predict_ch3(net, test_iter)

# "花时" means "time spent".
print("花时："+str(timer.stop())+"sec")
